// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

#include <Eigen/CXX11/Tensor>

using Eigen::Tensor;

template<int DataLayout>
static void
test_dimension_failures()
{
	Tensor<int, 3, DataLayout> left(2, 3, 1);
	Tensor<int, 3, DataLayout> right(3, 3, 1);
	left.setRandom();
	right.setRandom();

	// Okay; other dimensions are equal.
	Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);

	// Dimension mismatches.
	VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
	VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));

	// Axis > NumDims or < 0.
	VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
	VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
}

template<int DataLayout>
static void
test_static_dimension_failure()
{
	Tensor<int, 2, DataLayout> left(2, 3);
	Tensor<int, 3, DataLayout> right(2, 3, 1);

#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
	// Technically compatible, but we static assert that the inputs have same
	// NumDims.
	Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
#endif

	// This can be worked around in this case.
	Tensor<int, 3, DataLayout> concatenation = left.reshape(Tensor<int, 3>::Dimensions(2, 3, 1)).concatenate(right, 0);
	Tensor<int, 2, DataLayout> alternative =
		left
			// Clang compiler break with {{{}}} with an ambiguous error on copy constructor
			// the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
			// Solution:
			// either the code should change to
			//  Tensor<int, 2>::Dimensions{{2, 3}}
			// or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
			.concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
}

template<int DataLayout>
static void
test_simple_concatenation()
{
	Tensor<int, 3, DataLayout> left(2, 3, 1);
	Tensor<int, 3, DataLayout> right(2, 3, 1);
	left.setRandom();
	right.setRandom();

	Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
	VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
	VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
	VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
	for (int j = 0; j < 3; ++j) {
		for (int i = 0; i < 2; ++i) {
			VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
		}
		for (int i = 2; i < 4; ++i) {
			VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
		}
	}

	concatenation = left.concatenate(right, 1);
	VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
	VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
	VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
	for (int i = 0; i < 2; ++i) {
		for (int j = 0; j < 3; ++j) {
			VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
		}
		for (int j = 3; j < 6; ++j) {
			VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
		}
	}

	concatenation = left.concatenate(right, 2);
	VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
	VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
	VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
	for (int i = 0; i < 2; ++i) {
		for (int j = 0; j < 3; ++j) {
			VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
			VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
		}
	}
}

// TODO(phli): Add test once we have a real vectorized implementation.
// static void test_vectorized_concatenation() {}

static void
test_concatenation_as_lvalue()
{
	Tensor<int, 2> t1(2, 3);
	Tensor<int, 2> t2(2, 3);
	t1.setRandom();
	t2.setRandom();

	Tensor<int, 2> result(4, 3);
	result.setRandom();
	t1.concatenate(t2, 0) = result;

	for (int i = 0; i < 2; ++i) {
		for (int j = 0; j < 3; ++j) {
			VERIFY_IS_EQUAL(t1(i, j), result(i, j));
			VERIFY_IS_EQUAL(t2(i, j), result(i + 2, j));
		}
	}
}

// Test entry point. Each layout-templated subtest runs once for ColMajor and
// once for RowMajor. The lvalue subtest is not templated and uses the default
// layout.
EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
{
	CALL_SUBTEST(test_dimension_failures<ColMajor>());
	CALL_SUBTEST(test_dimension_failures<RowMajor>());
	CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
	CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
	CALL_SUBTEST(test_simple_concatenation<ColMajor>());
	CALL_SUBTEST(test_simple_concatenation<RowMajor>());
	// Disabled: test_vectorized_concatenation is not defined (its declaration
	// is commented out in this file).
	// CALL_SUBTEST(test_vectorized_concatenation());
	CALL_SUBTEST(test_concatenation_as_lvalue());
}
